import os
import paddle
import numpy as np


BATCH_SIZE=20

train_dataset = paddle.text.datasets.UCIHousing(mode='train')
valid_dataset = paddle.text.datasets.UCIHousing(mode='test')

#用于训练的数据加载器，每次随机读取批次大小的数据，剩余不满足批大小的数据丢弃
train_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

#用于测试的数据加载器，每次随机读取批次大小的数据
valid_loader = paddle.io.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)

#用于打印，查看uci_housing数据
print(train_dataset[0])

# 输入数据形状为[13]，输出形状[1]
net = paddle.nn.Linear(13, 1)

optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=net.parameters())

import matplotlib.pyplot as plt


iter = 0
iters = []
train_costs = []

def draw_train_process(iters, train_costs):
    title="training cost"
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=14)
    plt.ylabel("cost", fontsize=14)
    plt.plot(iters, train_costs, color='red', label='training cost') 
    plt.grid()
    plt.show()

EPOCH_NUM=50

#训练EPOCH_NUM轮
for pass_id in range(EPOCH_NUM):                                  
    # 开始训练并输出最后一个batch的损失值
    train_cost = 0

    #遍历train_reader迭代器
    for batch_id, data in enumerate(train_loader()):
        inputs = paddle.to_tensor(data[0])
        labels = paddle.to_tensor(data[1])
        out = net(inputs)
        train_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
        train_loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        #每运行40步，输出一次信息,
        #注意batch_id=0时也输出, 即 0, 40, 80, ...
        if batch_id % 40 == 0:
            print("Pass:%d, Cost:%0.5f" % (pass_id, train_loss))    
        
        iter = iter + BATCH_SIZE
        iters.append(iter)
        train_costs.append(train_loss.numpy()[0])
       
    # 开始测试并输出最后一个batch的损失值
    test_loss = 0

    #遍历test_reader迭代器
    for batch_id, data in enumerate(valid_loader()):               
        inputs = paddle.to_tensor(data[0])
        labels = paddle.to_tensor(data[1])
        out = net(inputs)
        test_loss = paddle.mean(paddle.nn.functional.square_error_cost(out, labels))
        
    #打印最后一个batch的损失值
    print('Test:%d, Cost:%0.5f' % (pass_id, test_loss))   